function fig4_devaluation_analysis(path)
% FIG4_DEVALUATION_ANALYSIS  Plot T-maze devaluation results for several training lengths.
%
%   fig4_devaluation_analysis(path) opens a new figure and, for every entry
%   of total_steps, reads the per-episode result files
%
%       <path>search_mpz_0.1_s3s_<steps>_thr_0.05/details/
%           tmaze_<stage>_devaluation_<id>_episode_<k>.mat
%
%   with <stage> in {'before', 'after'}, <id> in ids and <k> in
%   episode_ranges. `path` is concatenated directly in front of the
%   directory name, so it must already end with a file separator.
%
%   Fields read from every file:
%       info           - per-step matrix; column 1 is plotted as x and
%                        column 2 as y, and the number of rows is used as
%                        the episode length.
%       model_sig_z_p  - averaged and shown as "STD of prior intention".
%       model_sig_z_q  - averaged and shown as "STD of posterior intention".
%
%   An episode counts as "left" when it has fewer than 60 rows and its
%   final x is negative, "right" when it has fewer than 60 rows and its
%   final x is positive, and "failed" otherwise.
%
%   Figure contents:
%       row 1 - post-devaluation trajectories drawn on top of draw_tmaze;
%       row 2 - left/right exit ratios before and after devaluation, with
%               chi-square comparisons (before vs. after);
%       bottom - prior STD, posterior STD, habitual ratio (Welch t-tests
%               between training lengths) and behaviour change rate
%               (chi-square tests between training lengths).
%
%   Requires draw_tmaze and superbar on the MATLAB path, ttest2 and
%   chi2cdf (the latter through the local function chi2test).

    colors_stages = {hsv2rgb([0, 0.0, 0.85]), hsv2rgb([0, 0.0, 0.7])};
    colors = {hsv2rgb([0, 0.8, 0.4]), hsv2rgb([0.6, 0.4, 0.8])};

    ids = 0:49;                 % run identifiers used in the file names

    episode_ranges = 1:2;       % episode indices used in the file names
    N = length(episode_ranges);

    total_steps = [140000,  180000, 260000, 420000];
    explains = ["40k steps training", "80k steps training", "160k steps training", "320k steps training"];
    explains_short = ["40k", "80k", "160k", "320k"];

    % Accumulators. The sigma/habit arrays are indexed (id + 1, nn, k) and
    % grow on assignment; the ratio arrays are indexed (id + 1, nn).
    ratio_right_before_all = [];
    ratio_right_after_all = [];
    sigma_prior_before_all = [];
    sigma_post_before_all = [];
    sigma_prior_after_all = [];
    sigma_post_after_all = [];
    habit_ratio_after_all = [];

    stages = {'before', 'after'};

    figure;
    set(gcf, 'Position', [240, 100, 150 + 350 * length(total_steps) , 800]);

    for nn = 1 : length(total_steps)

        % filesep instead of a literal '\' so the directory also resolves
        % on non-Windows platforms (identical result on Windows).
        details_dir = strcat(path, ...
            sprintf('search_mpz_0.1_s3s_%d_thr_0.05', total_steps(nn)), ...
            filesep, 'details', filesep);

        % ----------------------------------------------

        % Per-id outcome counters (one entry per id, summed over episodes).
        count_left_before = zeros(size(ids));
        count_left_after = zeros(size(ids));
        count_right_before = zeros(size(ids));
        count_right_after = zeros(size(ids));
        count_failed_before = zeros(size(ids));
        count_failed_after = zeros(size(ids));

        % Row 1: maze background for this training length.
        subplot(3, length(total_steps) + 1, nn + 1)
        draw_tmaze;
        ylim([-12.5, 6.5])
        title(explains{nn})

        for id = ids

            for k = 1:length(episode_ranges)

                for kk = 1:2
                    stage = stages{kk};

                    filename = sprintf("tmaze_%s_devaluation_%d_episode_%d.mat", stage, id, episode_ranges(k));
                    data = load(strcat(details_dir,  filename));

                    if kk == 1 % before devaluation
                        if size(data.info, 1) < 60 && data.info(end, 1) < 0
                            count_left_before(id + 1) =  count_left_before(id + 1) + 1;
                        elseif size(data.info, 1) < 60 && data.info(end, 1) > 0
                            count_right_before(id + 1) =  count_right_before(id + 1) + 1;
                        else
                            count_failed_before(id + 1) =  count_failed_before(id + 1) + 1;
                        end

                        sigma_prior_before_all(id + 1, nn, k) = mean(data.model_sig_z_p, 'all');
                        sigma_post_before_all(id + 1, nn, k) = mean(data.model_sig_z_q, 'all');

                    elseif kk == 2 % after devaluation

                        % ----------------- plot trajectories -----------------
                        % Drawn from the same load that feeds the statistics,
                        % so every "after" file is read only once. The current
                        % axes is still the maze subplot selected above.
                        hold on
                        x_traj = data.info(:, 1);
                        y_traj = data.info(:, 2);
                        line(x_traj, y_traj, 'Color', [0.3 0.3 0.3 0.25], 'linewidth', 0.6);

                        if size(data.info, 1) < 60 && data.info(end, 1) < 0
                            count_left_after(id + 1) =  count_left_after(id + 1) + 1;
                        elseif size(data.info, 1) < 60 && data.info(end, 1) > 0
                            count_right_after(id + 1) =  count_right_after(id + 1) + 1;
                        else
                            count_failed_after(id + 1) =  count_failed_after(id + 1) + 1;
                        end

                        % Only the post-devaluation file contributes to the
                        % "*_after_all" statistics.
                        sigma_prior_after_all(id + 1, nn, k) = mean(data.model_sig_z_p, 'all');
                        sigma_post_after_all(id + 1, nn, k) = mean(data.model_sig_z_q, 'all');

                        % Precision-weighted share of the prior:
                        % sig_p^-2 / (sig_q^-2 + sig_p^-2), averaged over all entries.
                        habit_ratio_after_all(id + 1, nn, k) = mean(data.model_sig_z_p.^(-2) ./ (data.model_sig_z_q.^(-2) + data.model_sig_z_p.^(-2)), ...
                            "all");
                    end
                end

            end

        end

        % Row 2: exit ratios before/after devaluation.
        subplot(3, length(total_steps) + 1, length(total_steps) + 2 +nn );

        % Bars 1-2 are pre-devaluation, 5-6 post-devaluation; 3-4 are NaN
        % placeholders that leave a gap between the two groups.
        Y = [mean(count_left_before); mean(count_right_before) ; nan; nan ; mean(count_left_after); mean(count_right_after)] / N;
        % NOTE(review): the counts have one entry per id, yet the standard
        % error divides by sqrt(length(ids)*N) - confirm this is intended.
        E = [std(count_left_before); std(count_right_before) ; 0; 0 ; std(count_left_after); std(count_right_after)] / sqrt(length(ids)*N) / N;

        p_values = nan(numel(Y), numel(Y));

        % Chi-square test: left exits, before vs. after devaluation.
        n1 = sum(count_left_before);
        N1 = length(episode_ranges) * length(ids);
        n2 = sum(count_left_after);
        N2 = length(episode_ranges) * length(ids);

        [~, p_values(1,5)] = chi2test(n1, N1, n2, N2);
        if isnan(p_values(1,5))
            % chi2test yields NaN when the pooled proportion is 0 or 1;
            % treat that as "no difference".
            p_values(1,5) = 1;
        end

        % Chi-square test: right exits, before vs. after devaluation.
        n1 = sum(count_right_before);
        N1 = length(episode_ranges) * length(ids);
        n2 = sum(count_right_after);
        N2 = length(episode_ranges) * length(ids);

        [~, p_values(2,6)] = chi2test(n1, N1, n2, N2);
        if isnan(p_values(2,6))
            p_values(2,6) = 1;
        end

        % Make P symmetric, by copying the upper triangle onto the lower triangle
        lidx = tril(true(size(p_values)), -1);
        PT = p_values';
        p_values(lidx) = PT(lidx);

        bar_handle = superbar(Y, 'E', E, 'P', p_values, 'PLineOffset', 0.4, 'PStarFontSize', 10);

        hold on

        ylim([0, 1.9])
        yticks([0, 0.5, 1])

        if nn == 1
            ylabel("exit reaching ratio")
        end

        xticks([1.5,  5.5])
        xlim([0.5, 6.5])
        xticklabels(["pre-dev.", "post-dev."])

        % Odd-indexed bars (left exits) get colors{1}, even-indexed bars
        % (right exits) get colors{2}.
        for iBarSeries = 1:6
            set(bar_handle(iBarSeries), 'FaceColor', colors{2 - mod(iBarSeries, 2)}, 'EdgeColor', 'none');
        end

        if nn == length(total_steps)
            legend(bar_handle(1:2), ["devalued (left)", "valued (right)"])
        end

        title({explains{nn}, " "})

        box off

        % Fraction of right exits per id (NaN if an id has no counted episodes).
        ratio_right_before_all(:, nn) = count_right_before ./ (count_right_before + count_left_before + count_failed_before);
        ratio_right_after_all(:, nn) = count_right_after ./ (count_right_after + count_left_after + count_failed_after);

    end

    % Average over episodes (3rd dimension) -> one row per id, one column
    % per training length.
    sigma_prior_before_all = mean(sigma_prior_before_all, 3);
    sigma_prior_after_all = mean(sigma_prior_after_all, 3);

    sigma_post_before_all = mean(sigma_post_before_all, 3);
    sigma_post_after_all = mean(sigma_post_after_all, 3);

    habit_ratio_after_all = mean(habit_ratio_after_all, 3);

    %  ---------------------  summary panels (post-devaluation) ---------------

    for data_type = 1:4

        % NOTE(review): panel 1 is placed on a 2x4 grid and panels 2-4 on a
        % 3x4 grid, while the rows above use a 3x(length(total_steps)+1)
        % grid. Kept as in the original layout - confirm it is intended.
        if data_type > 1
            subplot(3,4,8 + data_type)
        else
            subplot(2,4,4 + data_type)
        end

        if data_type == 1
            data_to_plot = sigma_prior_after_all;

        elseif data_type == 2
            data_to_plot = sigma_post_after_all;

        elseif data_type == 3
            data_to_plot = habit_ratio_after_all;

        elseif data_type == 4
            data_to_plot = ratio_right_after_all - ratio_right_before_all;

        end

        means = mean(data_to_plot , 1, 'omitnan');
        % NOTE(review): data_to_plot has one row per id (episodes are already
        % averaged), yet the standard error divides by sqrt(length(ids)*N) -
        % confirm this is intended.
        errors = std(data_to_plot , [], 1, 'omitnan') / sqrt(length(ids)*N);
        p_values = nan(numel(means), numel(means));

        if data_type < 4

            % Welch (unequal-variance) two-sample t-tests between training
            % lengths. With four columns the pairs shown are 1-2, 2-3, 3-4
            % and 1-4; the 1-3 comparison is dropped.
            if length(means) >= 2
                [~, p_values(1,2)] = ttest2(data_to_plot (:,1), data_to_plot (:,2), 'Vartype', 'unequal');
            end
            if length(means) >= 3
                [~, p_values(2,3)] = ttest2(data_to_plot (:,2), data_to_plot (:,3), 'Vartype', 'unequal');
                [~, p_values(1,3)] = ttest2(data_to_plot (:,1), data_to_plot (:,3), 'Vartype', 'unequal');
            end
            if length(means) >= 4
                p_values(1,3) = nan;
                [~, p_values(1,4)] = ttest2(data_to_plot (:,1), data_to_plot (:,4), 'Vartype', 'unequal');
                [~, p_values(3,4)] = ttest2(data_to_plot (:,3), data_to_plot (:,4), 'Vartype', 'unequal');
            end

        else
            % Chi-square tests on counts reconstructed from the per-id ratio
            % differences.
            % NOTE(review): the differences can be negative (and NaN rows
            % propagate through sum), which yields counts outside [0, N_all]
            % - verify against the data.
            N_all = length(episode_ranges) * length(ids);
            if length(means) >= 2
                n1 = floor(N*sum(data_to_plot (:,1)));
                n2 = floor(N*sum(data_to_plot (:,2)));
                [~, p_values(1,2)] = chi2test(n1, N_all, n2, N_all);
            end
            if length(means) >= 3
                n3 = floor(N*sum(data_to_plot (:,3)));
                [~, p_values(2,3)] = chi2test(n2, N_all, n3, N_all);
                [~, p_values(1,3)] = chi2test(n1, N_all, n3, N_all);
            end
            if length(means) >= 4
                n4 = floor(N*sum(data_to_plot (:,4)));
                p_values(1,3) = nan;
                [~, p_values(1,4)] = chi2test(n1, N_all, n4, N_all);
                [~, p_values(3,4)] = chi2test(n3, N_all, n4, N_all);
            end

        end

        % Make P symmetric, by copying the upper triangle onto the lower triangle
        lidx = tril(true(size(p_values)), -1);
        PT = p_values';
        p_values(lidx) = PT(lidx);

        bar_handle = superbar(means, 'E', errors, 'P', p_values, 'PLineOffset', max(max(means)/2), 'PStarFontSize', 9);
        for iBarSeries = 1:length(bar_handle)
            set(bar_handle(iBarSeries), 'FaceColor', colors_stages{2}, 'EdgeColor', 'none');
        end

        hold on

        if data_type < 4
            % Overlay the individual per-id values on each bar. The loop
            % follows the number of columns instead of a fixed 1:4.
            for i = 1:size(data_to_plot, 2)
                plot(repmat(i, size(data_to_plot, 1), 1), data_to_plot(:, i), 'k.', 'MarkerSize', 6);
            end
        end

        box off

        if data_type == 1
            ylabel("STD of prior intention")
        elseif data_type == 2
            ylabel("STD of posterior intention")
        elseif data_type == 3
            ylabel("habitual ratio")
        elseif data_type == 4
            ylabel("behavior change rate")
        end

        if data_type == 1
            ylim([0, 0.75])
        elseif data_type == 2
            ylim([0, 0.85])
        elseif data_type == 3
            ylim([0, 2])
        elseif data_type == 4
            ylim([0, 0.5])
        end

        xticks(1:length(total_steps))
        xticklabels(explains_short)
        xlim([0.5, length(total_steps) + 0.5])

    end

    % NOTE(review): kept verbatim from the original. This assigns one child
    % handle of the figure into slot 14 of its Children list, presumably to
    % change stacking order - confirm it does what is intended.
    h = gcf;
    h.Children(14) = h.Children(4);

end

function [chi2stat, p_value] = chi2test(n1, N1, n2, N2)
% CHI2TEST  Pearson chi-square test for equality of two proportions.
%
%   [chi2stat, p_value] = chi2test(n1, N1, n2, N2) compares n1 hits out of
%   N1 trials with n2 hits out of N2 trials.
%
%   Inputs:
%       n1, n2 - observed hit counts in group 1 and group 2.
%       N1, N2 - total number of trials in group 1 and group 2.
%
%   Outputs:
%       chi2stat - chi-square statistic over the four cells
%                  (hit / miss in each group).
%       p_value  - upper-tail probability of chi2stat under a chi-square
%                  distribution with 1 degree of freedom.
%
%   Both outputs are NaN when the pooled proportion is exactly 0 or 1,
%   because an expected count of zero produces 0/0; callers are expected
%   to handle that case.

       % Pooled estimate of proportion
       p0 = (n1+n2) / (N1+N2);
       % Expected counts under H0 (null hypothesis)
       n10 = N1 * p0;
       n20 = N2 * p0;
       % Chi-square test, by hand
       observed = [n1 N1-n1 n2 N2-n2];
       expected = [n10 N1-n10 n20 N2-n20];
       chi2stat = sum((observed-expected).^2 ./ expected);
       % Ask for the upper tail directly: 1 - chi2cdf(...) cancels to
       % exactly 0 for large statistics and loses very small p-values.
       p_value = chi2cdf(chi2stat, 1, 'upper');
end